-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Lora] correct lora saving & loading #2655
Conversation
@@ -265,15 +262,6 @@ def save_function(weights, filename): | |||
# Save the model | |||
state_dict = model_to_save.state_dict() | |||
|
|||
# Clean the folder from a previous save |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was a bad copy-paste and is not needed
The documentation is not available anymore as the PR was closed or merged. |
The following code now works without problems: from diffusers import StableDiffusionPipeline
import torch
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.unet.load_attn_procs("PhanAnh/sd-model-finetuned-lora", weight_name="pytorch_model.bin")
pipe.unet.save_attn_procs("./temp")
pipe.unet.load_attn_procs("./temp")
pipe.unet.save_attn_procs("./temp", safe_serialization=True)
pipe.unet.load_attn_procs("./temp")
pipe.unet.save_attn_procs("./temp", weight_name="lora.bin")
pipe.unet.load_attn_procs("./temp", weight_name="lora.bin")
pipe.unet.save_attn_procs("./temp", weight_name="lora.safetensors", safe_serialization=True)
pipe.unet.load_attn_procs("./temp", weight_name="lora.safetensors") |
if (is_safetensors_available() and weight_name is None) or ( | ||
weight_name is not None and weight_name.endswith(".safetensors") | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to add a small comment highlighting that we always look for safetensors first just for easy readability?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point
src/diffusers/loaders.py
Outdated
"`weights_name` is deprecated, please use `weight_name` instead.", | ||
take_from=kwargs, | ||
) | ||
print(weight_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs to be removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks much!
Left minor nits.
* [Lora] correct lora saving & loading * fix final * Apply suggestions from code review
* [Lora] correct lora saving & loading * fix final * Apply suggestions from code review
* [Lora] correct lora saving & loading * fix final * Apply suggestions from code review
Should correct: #2616